#!/usr/bin/env python

import csv
import cv2
import sys
import signal
import threading
import numpy as np
from os.path import expanduser
from cv_bridge import CvBridge, CvBridgeError

import rospy
from sensor_msgs.msg import Image, JointState
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.srv import RobotInfo, TorqueEnable

from python_qt_binding.QtGui import QFont, QImage, QPixmap
from python_qt_binding.QtCore import Qt, QTimer
from python_qt_binding.QtWidgets import *

### Class that contains a GUI with parameters to perform object tracking based on color segmentation
class ColorTrackerGUI(QWidget):

    ### @brief Initialization of the ColorTrackerGUI class; sets up the GUI layout and callback functions
    def __init__(self):
        """Connect to ROS, build the GUI layout, and start the image refresh timer.

        Blocks until the 'torque_enable' and 'get_robot_info' services are
        available, and then until a first camera image and a first joint state
        message have been received.

        Raises:
            rospy.ROSInterruptException: if ROS shuts down before the first
                image and joint state message arrive.
        """
        super(ColorTrackerGUI, self).__init__()

        # State written by the ROS subscriber callbacks, each guarded by its own mutex
        self.cv_image = None                    # Latest camera frame, converted to OpenCV 'bgr8'
        self.joint_states = None                # Latest JointState message
        self.js_mutex = threading.Lock()        # Guards 'self.joint_states'
        self.image_mutex = threading.Lock()     # Guards 'self.cv_image'
        self.bridge = CvBridge()                # Converts sensor_msgs/Image to an OpenCV image

        # ROS services, publisher, and subscribers
        rospy.wait_for_service("torque_enable")
        rospy.wait_for_service("get_robot_info")
        self.srv_torque_enable = rospy.ServiceProxy("torque_enable", TorqueEnable)
        self.srv_robot_info = rospy.ServiceProxy("get_robot_info", RobotInfo)
        self.pub_cmds = rospy.Publisher("commands/joint_group", JointGroupCommand, queue_size=10)
        self.sub_img = rospy.Subscriber("lifecam/image_raw", Image, self.image_cb)
        self.sub_joint_states = rospy.Subscriber("joint_states", JointState, self.joint_state_cb)

        # GUI: HSV threshold controls in column 0, a divider in column 1, image streams in column 2
        self.name_map = {}                                      # Widgets keyed by name, shared by the callbacks
        self.big_font = QFont("Helvetica", 14, QFont.Bold)      # Header text
        self.small_font = QFont("Helvetica", 9, QFont.Bold)     # Everything else
        self.master_layout = QGridLayout()                      # Root container for the GUI
        self.create_hsv_component("Hue Min", 179, 0, 0, 0)
        self.create_hsv_component("Hue Max", 179, 179, 1, 0)
        self.create_horz_line(2, 0)
        self.create_hsv_component("Saturation Min", 255, 0, 3, 0)
        self.create_hsv_component("Saturation Max", 255, 255, 4, 0)
        self.create_horz_line(5, 0)
        self.create_hsv_component("Value Min", 255, 0, 6, 0)
        self.create_hsv_component("Value Max", 255, 255, 7, 0)
        self.create_button_block(8, 0)
        self.create_vert_line(0, 1)
        self.create_image_box(0, 2)
        self.setLayout(self.master_layout)

        # Wait for the first image and joint state message. Sleep between polls
        # instead of spinning, and give up if ROS is shutting down so that the
        # constructor cannot hang forever.
        while True:
            with self.image_mutex:
                have_image = self.cv_image is not None
            with self.js_mutex:
                have_joint_states = self.joint_states is not None
            if have_image and have_joint_states:
                break
            if rospy.is_shutdown():
                raise rospy.ROSInterruptException("ROS shut down before an initial image and joint state were received")
            rospy.sleep(0.01)

        # Start the commands at the current joint positions
        with self.js_mutex:
            initial_positions = list(self.joint_states.position)
        self.joint_commands = JointGroupCommand("turret", initial_positions)
        self.robot_info = self.srv_robot_info("group", "turret")    # Provides the joint limits used while tracking

        # Pixel the tracker tries to keep the target on; assumes 320x240 frames,
        # which is also the fixed size of the image labels
        self.center_x = 320/2.0
        self.center_y = 240/2.0

        # Refresh the image streams every 33 ms (~30 Hz)
        tmr_images = QTimer(self)
        tmr_images.timeout.connect(self.tmr_images_cb)
        tmr_images.start(33)
        self.show()

# ROS Subscribers

    ### @brief ROS Subscriber callback function to get updated joint states
    ### @param msg - updated JointState message
    def joint_state_cb(self, msg):
        with self.js_mutex:
            self.joint_states = msg

    ### @brief ROS Subscriber callback function to get the latest images from the 'usb_cam' node
    ### @param msg - updated Image message
    def image_cb(self, msg):
        with self.image_mutex:
            try:
              self.cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8")
            except CvBridgeError as e:
              print(e)

# PyQt Timer

    ### @brief A QTimer used to update the display for four image streams
    ### @details - The four image streams are:
    ###                1) Raw Image - the raw image from the camera
    ###                2) Mask Image - black/white image where 'black' represents the pixels we don't care about and 'white' represents the pixels we do care about
    ###                3) Processed Image - only displays the part of the raw image corresponding to the 'white' part of the Mask Image
    ###                4) Overlay Image - the raw image overlaid with 'yellow' boundary points marking the contour of interest, and a 'red' point marking the centroid of that contour
    def tmr_images_cb(self):
        """Refresh the four image streams and, while tracking, step the turret toward the target.

        The latest frame is thresholded in HSV space using the six slider
        values. The largest contour in the resulting mask is the target; its
        enclosing-circle center is marked on the overlay. While the track
        button reads 'Stop Tracking', each joint command is moved by a fixed
        step whenever that center is more than a deadband away from
        (self.center_x, self.center_y), and the commands are published.
        """
        deadband = 20       # pixels of error tolerated before a joint is moved
        step = 0.02         # change applied to a joint command per timer tick

        with self.image_mutex:
            cv_image = self.cv_image.copy()

        # Display the 'Raw Image' (the frame is BGR, so swap the red and blue channels)
        self._show_image("Raw Image", cv_image, QImage.Format_RGB888, swap_channels=True)

        # Create and Display the Black/White Image Mask
        hsv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2HSV)
        sliders = dict((key, self.name_map[key]["slider"].value()) for key in
                       ("Hue Min", "Hue Max", "Saturation Min", "Saturation Max", "Value Min", "Value Max"))
        lower_threshold = np.array([sliders["Hue Min"], sliders["Saturation Min"], sliders["Value Min"]])
        upper_threshold = np.array([sliders["Hue Max"], sliders["Saturation Max"], sliders["Value Max"]])
        cv_mask = cv2.inRange(hsv_image, lower_threshold, upper_threshold)
        # Using some Morphological operations to clean up 'noise' as explained at https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_morphological_ops/py_morphological_ops.html#opening
        # Ellipsoid kernels seem to do the job
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        cv_mask = cv2.morphologyEx(cv_mask, cv2.MORPH_OPEN, kernel)
        cv_mask = cv2.morphologyEx(cv_mask, cv2.MORPH_DILATE, kernel)
        self._show_image("Mask Image", cv_mask, QImage.Format_Grayscale8)

        # Create and Display the Processed Image (only show the parts of the image that were not 'blacked out' in the previous step)
        hsv_processed = cv2.bitwise_and(hsv_image, hsv_image, mask=cv_mask)
        cv_processed = cv2.cvtColor(hsv_processed, cv2.COLOR_HSV2RGB)
        self._show_image("Processed Image", cv_processed, QImage.Format_RGB888)

        # Create the Overlay Image; it is drawn on in HSV space, so the colors below are HSV triples
        cv_overlay = hsv_image.copy()
        # '[-2:]' keeps (contours, hierarchy) whether findContours returns two or three values
        contours, _ = cv2.findContours(cv_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
        # if a 'contour' (fancy word for boundary) was found...
        if len(contours) > 0:
            # use the largest contour from all the ones found
            c = max(contours, key=cv2.contourArea)
            # find the approximate center of that contour
            (x, y), _ = cv2.minEnclosingCircle(c)
            # draw a red circle at the center
            cv2.circle(cv_overlay, (int(x), int(y)), 5, (0, 255, 255), -1)
            # draw yellow points to mark the boundary
            cv2.drawContours(cv_overlay, c, -1, (30, 255, 255), 3)
            # Publish JointCommands so that the center pixel in the image overlaps the red circle mentioned above
            if self.name_map["track"].text() == 'Stop Tracking':
                error_x = self.center_x - x
                error_y = self.center_y - y
                if error_x > deadband:
                    self._step_joint(0, step)
                elif error_x < -deadband:
                    self._step_joint(0, -step)
                # the second joint moves in the opposite sense to its pixel error
                if error_y > deadband:
                    self._step_joint(1, -step)
                elif error_y < -deadband:
                    self._step_joint(1, step)
                self.pub_cmds.publish(self.joint_commands)

        cv_overlay = cv2.cvtColor(cv_overlay, cv2.COLOR_HSV2RGB)
        self._show_image("Overlay Image", cv_overlay, QImage.Format_RGB888)

    ### @brief Show an OpenCV image in one of the image stream labels
    ### @param name - key of the target QLabel in 'self.name_map'
    ### @param pixels - numpy image array (rows x columns [x channels])
    ### @param image_format - QImage.Format matching the layout of 'pixels'
    ### @param swap_channels - if True, swap the red and blue channels before display
    def _show_image(self, name, pixels, image_format, swap_channels=False):
        height, width = pixels.shape[:2]
        # Pass the row stride explicitly: without it Qt assumes every scanline is
        # 32-bit aligned, which skews images whose row size is not a multiple of 4
        image = QImage(pixels, width, height, pixels.strides[0], image_format)
        if swap_channels:
            image = image.rgbSwapped()
        self.name_map[name].setPixmap(QPixmap.fromImage(image))

    ### @brief Move one joint command by 'delta', clamped to the joint limit in the direction of travel
    ### @param index - index of the joint in 'self.joint_commands.cmd'
    ### @param delta - signed change to apply to the command
    def _step_joint(self, index, delta):
        target = self.joint_commands.cmd[index] + delta
        if delta > 0:
            target = min(target, self.robot_info.joint_upper_limits[index])
        else:
            target = max(target, self.robot_info.joint_lower_limits[index])
        self.joint_commands.cmd[index] = target

# Helper functions to build individual sections of the main GUI

    ### @brief Creates a GUI subsection for the Hue, Saturation, And Value parameters
    ### @param name - name of the parameter this subsection controls ('Hue Min', 'Saturation Max', etc...); also used as its key in 'self.name_map'
    ### @param max - maximum value that the Slider should have
    ### @param default - default value that the Slider should have
    ### @param row - row placement of the GUI subsection in the master grid layout
    ### @param col - column placement of the GUI subsection in the master grid layout
    def create_hsv_component(self, name, max, default, row, col):
        """Build one titled slider + textbox group and place it in the master layout.

        The textbox, the slider, and 'max' are stored in 'self.name_map[name]'
        under the keys 'display', 'slider', and 'max'.
        """
        section = QGridLayout()

        # Row 0: the title (centered) and the value textbox (right-aligned) share the middle cell
        title = self.create_label(name, self.big_font)
        value_box = self.create_textbox(str(default))
        section.addWidget(title, 0, 1, Qt.AlignCenter)
        section.addWidget(value_box, 0, 1, Qt.AlignRight)

        # Row 1: "0" label, slider, "<max>" label
        bar = self.create_slider(max, default)
        low_label = self.create_label("0", self.small_font)
        high_label = self.create_label(str(max), self.small_font)
        for column, widget in enumerate((low_label, bar, high_label)):
            section.addWidget(widget, 1, column, Qt.AlignCenter)

        # Keep the textbox and the slider in sync with each other
        def on_text_committed():
            self.update_slider_bar(name)

        def on_slider_moved():
            self.update_display(name)

        value_box.editingFinished.connect(on_text_committed)
        bar.valueChanged.connect(on_slider_moved)

        # Register the widgets so the callbacks can look them up by name
        self.name_map[name] = {'display' : value_box, 'slider' : bar, 'max' : max}

        self.master_layout.addLayout(section, row, col)

    ### @brief Creates a GUI subsection for the four image streams
    ### @param row - row placement of the GUI subsection in the master grid layout
    ### @param col - column placement of the GUI subsection in the master grid layout
    def create_image_box(self, row, col):
        """Build the 2x2 grid of titled image labels and place it in the master layout.

        Each image QLabel is fixed at 320x240 and stored in 'self.name_map'
        under its title so the timer callback can update it.
        """
        # parent container for this section
        grid_layout = QGridLayout()
        # (title, (grid row, grid column)) for each image stream
        streams = [
            ("Raw Image", (0, 0)),
            ("Mask Image", (0, 1)),
            ("Processed Image", (1, 0)),
            ("Overlay Image", (1, 1)),
        ]

        for title, (grid_row, grid_col) in streams:
            # subcontainer holding the title above the image
            vertical_layout = QVBoxLayout()
            # image stream title
            label = self.create_label(title, self.big_font)
            label.setAlignment(Qt.AlignCenter)
            # image stream 'window'
            pic_label = QLabel()
            pic_label.setFixedSize(320, 240)
            vertical_layout.addWidget(label)
            vertical_layout.addWidget(pic_label)
            grid_layout.addLayout(vertical_layout, grid_row, grid_col)
            self.name_map[title] = pic_label

        # add the layout to the master layout, spanning 9 rows and 1 column
        self.master_layout.addLayout(grid_layout, row, col, 9, 1)

    ### @brief Creates a GUI subsection for the 'Load Configs', 'Save Configs', and 'Start Tracking' buttons
    ### @param row - row placement of the GUI subsection in the master grid layout
    ### @param col - column placement of the GUI subsection in the master grid layout
    def create_button_block(self, row, col):
        """Build the row of 'Load Configs', 'Save Configs', and 'Start Tracking' buttons.

        Also calls the torque service with False for the 'turret' group, and
        stores the tracking button in 'self.name_map["track"]'.
        """
        button_row = QHBoxLayout()

        load_btn = self.create_button('Load Configs', self.small_font)
        save_btn = self.create_button('Save Configs', self.small_font)
        track_btn = self.create_button('Start Tracking', self.small_font)
        track_btn.setStyleSheet('QPushButton {color: green;}')
        # The GUI starts in the 'not tracking' state: request torque off for the 'turret' group
        self.srv_torque_enable("group", "turret", False)

        # Buttons separated (and surrounded) by stretchable space
        for btn in (load_btn, save_btn, track_btn):
            button_row.addStretch()
            button_row.addWidget(btn)
        button_row.addStretch()

        # Hook each button up to its handler
        handlers = ((load_btn, self.load_configs), (save_btn, self.save_configs), (track_btn, self.track))
        for btn, handler in handlers:
            btn.clicked.connect(handler)
        self.name_map["track"] = track_btn

        self.master_layout.addLayout(button_row, row, col)

# Helper functions to build small GUI components

    ### @brief Create a QLabel with some custom settings
    ### @param text - message that the label should display
    ### @param font - type of font to use
    ### @param label [out] - returns QLabel object
    def create_label(self, text, font):
        """Return a new QLabel showing 'text' in 'font'."""
        widget = QLabel(text)
        widget.setFont(font)
        return widget

    ### @brief Create a QPushButton with some custom settings
    ### @param text - message that the button should display
    ### @param font - type of font to use
    ### @param button [out] - returns QPushButton object
    def create_button(self, text, font):
        """Return a new QPushButton showing 'text' in 'font'."""
        widget = QPushButton(text)
        widget.setFont(font)
        return widget

    ### @brief Creates a QLineEdit box with some custom settings
    ### @param text - text that should be displayed
    ### @param textbox [out] - returns QLineEdit object
    def create_textbox(self, text):
        """Return a new 50-pixel-wide QLineEdit showing 'text', centered, in the small font."""
        line_edit = QLineEdit(text)
        line_edit.setAlignment(Qt.AlignCenter)
        line_edit.setFixedWidth(50)
        line_edit.setFont(self.small_font)
        return line_edit

    ### @brief Creates a QSlider with some custom settings
    ### @param max - max value of the slider
    ### @param default - default value of the slider
    ### @param slider [out] - returns QSlider object
    def create_slider(self, max, default):
        """Return a new 450-pixel-wide horizontal QSlider ranging over [0, max], set to 'default'."""
        bar = QSlider(Qt.Horizontal)
        bar.setFixedWidth(450)
        # The range is set before the value so that 'default' is applied against [0, max]
        bar.setRange(0, max)
        bar.setValue(default)
        return bar

    ### @brief Creates a horizontal line to act as a section divider
    ### @param row - row placement of the GUI subsection in the master grid layout
    ### @param col - column placement of the GUI subsection in the master grid layout
    def create_horz_line(self, row, col):
        """Add a 500-pixel-wide sunken horizontal divider to the master layout at (row, col)."""
        divider = QLabel()
        divider.setFixedWidth(500)
        # The frame style is set before the shadow, as in the original call order
        divider.setFrameStyle(QFrame.HLine)
        divider.setFrameShadow(QFrame.Sunken)
        for set_width in (divider.setLineWidth, divider.setMidLineWidth):
            set_width(2)
        self.master_layout.addWidget(divider, row, col, Qt.AlignCenter)

    ### @brief Creates a vertical line to act as a section divider
    ### @param row - row placement of the GUI subsection in the master grid layout
    ### @param col - column placement of the GUI subsection in the master grid layout
    def create_vert_line(self, row, col):
        """Add a 500-pixel-tall sunken vertical divider to the master layout, spanning 9 rows from (row, col)."""
        divider = QLabel()
        divider.setFixedHeight(500)
        # The frame style is set before the shadow, as in the original call order
        divider.setFrameStyle(QFrame.VLine)
        divider.setFrameShadow(QFrame.Sunken)
        for set_width in (divider.setLineWidth, divider.setMidLineWidth):
            set_width(2)
        self.master_layout.addWidget(divider, row, col, 9, 1, Qt.AlignCenter)

# Textbox callbacks

    ### @brief Event handler when a display is changed
    ### @param name - name of the slider group to which the 'slider' belongs ('Hue Min', 'Saturation Max', etc...)
    ### @details - the function updates the slider position to reflect that shown in the display
    def update_slider_bar(self, name):
        info = self.name_map[name]
        max = info['max']
        value = int(info['display'].text())
        if (value < 0):
            value = 0
            info['display'].setText(str(value))
        elif (value > max):
            value = max
            info['display'].setText(str(value))
        info['slider'].setValue(value)

# Slider callbacks

    ### @brief Event handler when a slider is changed
    ### @param name - name of the slider group to which the 'display' belongs ('Hue Min', 'Saturation Max', etc...)
    def update_display(self, name):
        info = self.name_map[name]
        slider_value = info['slider'].value()
        info['display'].setText(str(slider_value))

# Button callbacks

    ### @brief Event handler when the 'Load Configs' button is pressed
    ### @details - updates the sliders and textboxes in the GUI based on the HSV configs in the loaded file
    def load_configs(self):
        """Ask for a CSV file and apply its '<name>,<value>' rows to the matching sliders.

        Does nothing if the dialog is cancelled. Blank rows are skipped
        silently. Rows with fewer than two fields, a name that does not refer
        to a slider group in 'self.name_map', or a value that is not an
        integer are skipped with a warning rather than raising inside the Qt
        slot.
        """
        fname = QFileDialog.getOpenFileName(self, 'Open file', expanduser("~"), "CSV files (*.csv)")
        path = fname[0]
        if path == "":
            return
        with open(path) as csv_file:
            for row in csv.reader(csv_file):
                if not row:
                    continue
                if len(row) < 2:
                    rospy.logwarn("Skipping incomplete config row: %s" % row)
                    continue
                # 'self.name_map' also holds image labels and the track button;
                # only the slider groups are dictionaries with a 'slider' entry
                entry = self.name_map.get(row[0])
                if not isinstance(entry, dict) or 'slider' not in entry:
                    rospy.logwarn("Skipping unknown config name: %s" % row[0])
                    continue
                try:
                    value = int(row[1])
                except ValueError:
                    rospy.logwarn("Skipping non-integer value for '%s': %s" % (row[0], row[1]))
                    continue
                entry['slider'].setValue(value)

    ### @brief Event handler when the 'Save Configs' button is pressed
    ### @details - saves the HSV configs to a CSV file
    def save_configs(self):
        """Ask for a destination file and write the six slider values as '<name>,<value>' CSV rows.

        Does nothing if the dialog is cancelled.
        """
        chosen = QFileDialog.getSaveFileName(self, 'Save file', expanduser("~"), "CSV files (*.csv)")
        path = chosen[0]
        if path == "":
            return
        names = ("Hue Min", "Hue Max", "Saturation Min", "Saturation Max", "Value Min", "Value Max")
        rows = [[key, self.name_map[key]['slider'].value()] for key in names]
        with open(path, mode="w") as csv_file:
            csv.writer(csv_file).writerows(rows)

    ### @brief Event handler when the 'Start/Stop Tracking' button is pressed
    ### @details - starts or stops the motors from actuating based on image feedback
    def track(self):
        if (self.name_map["track"].text() == "Start Tracking"):
            self.name_map["track"].setText("Stop Tracking")
            self.name_map["track"].setStyleSheet('QPushButton {color: red;}')
            self.srv_torque_enable("group", "turret", True)
            with self.js_mutex:
                self.joint_commands.cmd = list(self.joint_states.position)
        else:
            self.name_map["track"].setText("Start Tracking")
            self.name_map["track"].setStyleSheet('QPushButton {color: green;}')
            self.srv_torque_enable("group", "turret", False)

if __name__ == '__main__':
    # The ROS node must exist before the GUI is built, since the GUI constructor uses ROS services and topics
    rospy.init_node('xsturret_color_tracker_calibrator')
    app = QApplication(sys.argv)
    color_tracker_gui = ColorTrackerGUI()
    # Restore the default SIGINT action so that Ctrl+C terminates the process
    # while the Qt event loop is running
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    exit_code = app.exec_()
    sys.exit(exit_code)
